library(ggplot2)
library(ggthemes)
library(here)
library(rwebppl)
library(svglite)
library(tidyverse)
source(here("utils.R"))
source(here("experiments", "blocksworld-main", "analysis", "model-utils.R"))


# NOTE: when knitting, the data has to be read in the same chunk that loads
# the libraries.
# -- TODO: find out why!!

# Experiment 1 (blocksworld-prior): processed experimental tables
fn <- here(
  "experiments", "blocksworld-prior", "results", "data-processed",
  "tables_experimental.csv"
)
dat.experiment1 <- read_csv(fn)

# Experiment 2 (blocksworld-main): experimental means; `utterance` becomes a
# factor and every row is tagged as produced by "human"
fn <- here(
  "experiments", "blocksworld-main", "results", "data-processed",
  "data_experimental_means.csv"
)
dat.experiment2.means <- read_csv(fn)
dat.experiment2.means <- mutate(
  dat.experiment2.means,
  utterance = factor(utterance)
)
dat.experiment2.means <- add_column(
  dat.experiment2.means,
  produced_by = "human"
)

# Base directory for model results; created if it does not exist yet
TARGET_DIR <- here("experiments", "blocksworld-main", "results", "model")
dir.create(TARGET_DIR, showWarnings = FALSE, recursive = TRUE)

Setup to run the model

# Bundle the model inputs: the prepared tables of experiment 1 and the unique
# stimulus ids that occur in them.
tables_to_wppl <- prepare_tables(dat.experiment1)
dat.to_wppl <- list(
  stimulus_ids = unique(pull(tables_to_wppl, stimulus_id)),
  tables = tables_to_wppl
)

# Sanity check on the utterances with highest theta: abort if any utterance
# is reported as not covered by the states.
dat.utts <- pids_to_cover_utts(dat.experiment1, 0.9)
stopifnot(length(dat.utts$utts_not_covered) == 0)

# dat.to_wppl$par_dirichlet = read_csv("results-dirichlet.csv") %>%
#                           add_column(alpha=list(alpha_1, alpha_2, alpha_3, alpha_4))
#                       nest("alpha_1", "alpha_2", "alpha_3", "alpha_4")  %>%
#                       select(stimulus_id, alpha) %>%
#                       pivot_wider(names_from = stimulus_id, values_from = alpha) %>% unnest()

Empirical vs. Model

# Model parameters, stored next to the tables in `dat.to_wppl`
dat.to_wppl$theta <- 0.72
dat.to_wppl$alpha <- 2
dat.to_wppl$cost_conditional <- 0.1
dat.to_wppl$cost_and <- 7
dat.to_wppl$cost_likely <- 7
dat.to_wppl$cost_neg <- 7

# Timestamped output directory for this run.
# The timestamp is built with an explicit format() pattern instead of
# coercing Sys.time() to a string and replacing separators: the implicit
# conversion depends on the R version (newer releases may append fractional
# seconds) and may drop the time part at exact midnight. The explicit pattern
# always yields "YYYY_MM_DD_HH_MM_SS".
now <- format(Sys.time(), "%Y_%m_%d_%H_%M_%S")
target_dir <- file.path(TARGET_DIR, now)
dir.create(target_dir, showWarnings = FALSE, recursive = TRUE)

Model data

# dat.to_wppl$utterances <- all_utterances[which(!str_detect(all_utterances, "likely.*"))]

# Run the model on the prepared input and take the `avg` component of its
# result, with utterances ordered as in `all_utterances`.
model_predictions <- run_model(dat.to_wppl)
model <- mutate(
  model_predictions$avg,
  utterance = factor(utterance, levels = all_utterances)
)

# Store the parameters and the predictions in this run's output directory
save_params(dat.to_wppl, target_dir)
save_predictions(model, target_dir, "model.csv")

Merge Model (unnormalized) and empirical data

# Stack the model rows on top of the participants' mean responses and sort
# by stimulus.
data <- arrange(bind_rows(model, dat.experiment2.means), stimulus_id)

# Wide layout with one column per value of `produced_by` (used for the
# correlation with the unnormed model data); rows without a human value are
# dropped.
data.compare <- pivot_wider(
  data,
  names_from = produced_by,
  values_from = mean
)
data.compare <- filter(data.compare, !is.na(human))

Make plots and compute correlation

# Plot data: all rows of `data`, utterances ordered as in `all_utterances`
df <- data %>%
  mutate(utterance = factor(utterance, levels = all_utterances))
# stimuli <- c("S34-806")
stimuli <- dat.to_wppl$stimulus_ids

# One correlation row per stimulus. The rows are collected in a preallocated
# list and bound once after the loop instead of growing a tibble (and
# rewriting the .rds file) on every iteration.
cor_rows <- vector("list", length(stimuli))
idx <- 0L
for (s in stimuli) {
  idx <- idx + 1L

  # Correlation test between human and model values for this stimulus
  dat.cor <- data.compare %>% filter(stimulus_id == s)
  val.cor <- cor.test(dat.cor$human, dat.cor$model)
  cor_rows[[idx]] <- tibble(
    cor = val.cor$estimate[[1]],
    p = val.cor$p.value,
    cor_null = val.cor$null.value[[1]],
    stimulus_id = s
  )

  # Shared by all three plots below: title and the non-zero rows of stimulus s
  plot_title <- paste0(s, " (corr.", round(val.cor$estimate[[1]], 2), ")")
  df.s <- df %>% filter(stimulus_id == s & mean != 0)

  # (1) human and model side by side
  p <- df.s %>%
    ggplot(aes(x = utterance, y = mean, fill = produced_by)) +
    geom_bar(position = "dodge", stat = "identity") +
    labs(y = "probability") +
    theme_classic(base_size = 30) +
    # ggtitle(plot_title) +
    theme(
      legend.position = "top",
      axis.text.x = element_text(angle = 60, hjust = 1),
      axis.text.y = element_text(size = 20)
    )

  print(p)
  fn <- file.path(target_dir, paste0("predictions-both-", s, ".png"))
  ggsave(fn, p, width = 10, height = 6)

  # (2) human rows only
  p <- df.s %>%
    filter(produced_by == "human") %>%
    ggplot(aes(x = utterance, y = mean, fill = produced_by)) +
    geom_bar(position = "dodge", stat = "identity") +
    labs(y = "probability (mean response)") +
    theme_classic(base_size = 25) +
    ggtitle(plot_title) +
    theme(
      legend.position = "none",
      axis.text.x = element_text(angle = 60, hjust = 1)
    )

  print(p)
  ggsave(file.path(target_dir, paste0("human-means-", s, ".png")),
         p, width = 10, height = 6)

  # (3) model rows only
  p <- df.s %>%
    filter(produced_by == "model") %>%
    ggplot(aes(x = utterance, y = mean, fill = produced_by)) +
    geom_bar(position = "dodge", stat = "identity") +
    labs(y = "probability (mean model)") +
    theme_classic(base_size = 25) +
    ggtitle(plot_title) +
    theme(
      legend.position = "none",
      axis.text.x = element_text(angle = 60, hjust = 1)
    )

  print(p)
  ggsave(file.path(target_dir, paste0("model-means-", s, ".png")),
         p, width = 10, height = 6)
}

# Bind the per-stimulus rows and store them once, after all stimuli are done
correlations <- bind_rows(cor_rows)
write_rds(correlations, file.path(target_dir, "correlations.rds"))

correlations %>% arrange(cor)

# Overall correlation across all stimuli
val.cor <- cor.test(data.compare$human, data.compare$model)
val.cor
# Knitted output of an earlier run, kept for reference:
## 
##  Pearson's product-moment correlation
## 
## data:  data.compare$human and data.compare$model
## t = 11.158, df = 98, p-value < 2.2e-16
## alternative hypothesis: true correlation is not equal to 0
## 95 percent confidence interval:
##  0.6466075 0.8234545
## sample estimates:
##       cor 
## 0.7480227
write_rds(val.cor, file.path(target_dir, "overall-correlation.rds"))

Rearrange data

Use one utterance (A) as the reference and plot the probability assigned to the other three utterances relative to that reference utterance, so that the plots are comparable.

# Utterances that take part in the comparison; "A" serves as the reference.
ref_utterances <- c("A", "C", "A > C", "C > A")

# Per stimulus and producer: keep the four utterances, remember the value for
# "A" in `ref`, and express every other utterance relative to it in
# `mean.ref` ("A" itself keeps its own value). `mean.ref` is rounded to one
# decimal, `mean` to three.
data.ref <- data %>%
  group_by(stimulus_id, produced_by) %>%
  filter(utterance %in% ref_utterances) %>%
  pivot_wider(names_from = utterance, values_from = mean) %>%
  mutate(ref = A) %>%
  pivot_longer(
    cols = all_of(ref_utterances),
    names_to = "utterance",
    values_to = "mean"
  ) %>%
  mutate(
    mean.ref = round(if_else(utterance == "A", mean, mean / ref), 1),
    mean = round(mean, 3)
  )

Make plots for model/human estimates relative to reference utterance

# Bar label: the relative value followed by the rounded mean in parentheses,
# e.g. "0.5 (0.123)".
data.ref <- data.ref %>%
  mutate(
    lab.ref = as.character(mean.ref),
    lab = paste0("(", as.character(mean), ")")
  ) %>%
  unite("lab", lab.ref, lab, sep = " ")

# One comparison plot per stimulus, saved as comparison-<stimulus>.png
for (s in stimuli) {
  data.s <- data.ref %>%
    filter(stimulus_id == s) %>%
    mutate(utterance = factor(utterance, levels = all_utterances))

  # Per-stimulus correlation, shown in the plot title
  dat.cor <- data.compare %>% filter(stimulus_id == s)
  val.cor <- cor.test(dat.cor$human, dat.cor$model)

  # All layers are combined with `+`. Joining facet_wrap() to labs() with
  # `%>%` would pass the facet specification into labs() as an argument
  # instead of adding it to the plot.
  p <- data.s %>%
    ggplot(aes(x = utterance, y = mean.ref, fill = produced_by)) +
    geom_bar(position = "dodge", stat = "identity") +
    geom_text(aes(label = lab), vjust = 0,
              position = position_dodge(width = 1)) +
    facet_wrap(~stimulus_id) +
    labs(y = "prediction wrt pred. for A") +
    theme(
      legend.position = "bottom",
      axis.text.x = element_text(angle = 60, hjust = 1)
    ) +
    ggtitle(paste0(s, " (corr.", round(val.cor$estimate[[1]], 2), ")"))

  print(p)
  ggsave(file.path(target_dir, paste0("comparison-", s, ".png")),
         p, width = 10, height = 6)
}

# Stimuli with a correlation below 0.6, weakest first
correlations %>% filter(cor < 0.6) %>% arrange(cor)

Plot data as scatter plot

# Wide layout for the scatter plots: drop the helper columns and spread the
# rounded means into one column per value of `produced_by`.
data.scatter <- data.ref %>%
  select(-c(ref, mean.ref, lab)) %>%
  group_by(stimulus_id, utterance) %>%
  pivot_wider(names_from = produced_by, values_from = mean) %>%
  arrange(stimulus_id)

# One model-vs-human scatter plot with a linear fit per stimulus
for (s in stimuli) {
  scatter.s <- filter(data.scatter, stimulus_id == s)

  p <- ggplot(scatter.s, aes(x = model, y = human)) +
    geom_point() +
    geom_smooth(method = "lm") +
    ggtitle(s)

  print(p)
  fn <- paste0(
    target_dir, .Platform$file.sep, "model-vs-human-scatter-", s, ".png"
  )
  ggsave(fn, p, width = 10, height = 6)
}